# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import dataclasses
import itertools
import math

import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np

from . import fragmented_array as fa
from . import mma_utils
from . import utils


c = utils.c
bytewidth = utils.bytewidth


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
  """A FragmentedArray that has is synchronized with the async proxy.

  This implies that it requires no additional synchronization when passed in
  as a WGMMA accumulator. In particular, when created from a
  FragmentedArray, the necessary synchronization is inserted at construction.
  """
  _original_layout: fa.FragmentedLayout
  _value: fa.FragmentedArray

  def __init__(
      self,
      *,
      _value: fa.FragmentedArray,
      _original_layout: fa.FragmentedLayout,
      _sync: bool = True,
  ):
    self._original_layout = _original_layout
    self._value = _value
    if _sync:
      self._value = wgmma_fence(_value)

  @property
  def value(self) -> fa.FragmentedArray:
    return self._value.to_layout(self._original_layout)

  @classmethod
  def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
    if m % 64 or n % 8:
      raise ValueError
    if is_signed is False:  # pylint: disable=g-bool-id-comparison
      raise TypeError("PTX does not support unsigned WGMMA accumulators")
    f32 = ir.F32Type.get()
    if dtype is None:
      dtype = f32
    if ir.IntegerType.isinstance(dtype):
      zero = arith.constant(dtype, ir.IntegerAttr.get(dtype, 0))
    else:
      zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
    return cls.from_registers(
        fa.FragmentedArray.splat(
            zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed
        )
    )

  @classmethod
  def from_registers(cls, registers):
    original_layout = registers.layout
    if registers.layout != fa.WGMMA_LAYOUT and registers.layout != fa.WGMMA_LAYOUT_ACC_32BIT:
      raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
    if utils.bitwidth(registers.mlir_dtype) == 32:
      registers = registers.to_layout(fa.WGMMA_LAYOUT_ACC_32BIT)
    return cls(_value=registers, _original_layout=original_layout)

  def tree_flatten(self):
    return (self._value,), (self._original_layout,)

  @classmethod
  def tree_unflatten(cls, aux, value):
    return cls(_value=value[0], _original_layout=aux[0], _sync=False)


def _supported_wgmma_types(dtype, abtype) -> bool:
  input_types_are = lambda ty: ty.isinstance(abtype)
  f16_acc_types = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType)
  if ir.F32Type.isinstance(dtype):
    return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types))
  elif ir.F16Type.isinstance(dtype):
    return any(input_types_are(ty) for ty in f16_acc_types)
  elif ir.IntegerType.get_signless(32).isinstance(dtype):
    return input_types_are(ir.IntegerType.get_signless(8))
  else:
    return False


def wgmma_m64(
    acc: np.ndarray,  # of register Values
    a,
    b_descriptor: ir.Value,
    a_transpose: bool | None,
    b_transpose: bool,
    a_k_stride: int | None,
    b_k_stride: int,
    n: int,
    swizzle: int,
    element_type: ir.Type,
):
  out_ty = ir.VectorType(acc.flat[0].type).element_type
  if not _supported_wgmma_types(out_ty, element_type):
    raise ValueError(f"Unsupported wgmma types {(out_ty, element_type)=}")
  if n % 8:
    raise ValueError

  bf16 = ir.BF16Type.get()
  f16 = ir.F16Type.get()
  i8 = ir.IntegerType.get_signless(8)
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  f8e5m2 = ir.Float8E5M2Type.get()
  f8e4m3fn = ir.Float8E4M3FNType.get()
  if b_k_stride % 16:
    raise ValueError
  # Only 16-bit types support transposes
  supports_transpose = bytewidth(element_type) == 2
  if not supports_transpose and (a_transpose or b_transpose):
    raise ValueError("Only f16 WGMMA supports transposes")
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    if a.mlir_dtype not in {bf16, f16, i8, f8e5m2, f8e4m3fn}:
      raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
    # Column count must be equal to swizzle // bytewidth.
    elt_bytewidth = utils.bytewidth(element_type)
    swizzle_elems = swizzle // elt_bytewidth
    if a.shape != (64, swizzle_elems):
      raise ValueError("Unsupported A register array shape")
    if a.layout not in {fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_8BIT}:
      raise ValueError("Unsupported A register array layout")
    if a_k_stride is not None or a_transpose is not None:
      raise ValueError("Unsupported WGMMA features with A in registers")
  else:
    if a_k_stride is None or a_k_stride % 16:
      raise ValueError
    if a_transpose is None:
      raise ValueError

  if ir.F32Type.isinstance(out_ty) or out_ty == i32:
    num_acc_regs = n // 2
    out_ty_field = ir.VectorType.get((1,), out_ty)
    acc_regs = list(acc.flat)
    assert acc_regs[0].type == ir.VectorType.get((1,), out_ty)
    to_acc_vec_regs = lambda regs: np.array(regs).reshape(acc.shape)
    acc_constraint = "r" if ir.IntegerType.isinstance(out_ty) else "f"
  elif ir.F16Type.isinstance(out_ty):
    num_acc_regs = n // 4
    out_ty_field = i32
    acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
    vec_ty = ir.VectorType(acc.flat[0].type)
    to_acc_vec_regs = lambda regs: np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
    acc_constraint = "r"
  else:
    raise ValueError(
        f"WGMMA instruction only supports f32, f16 and s32 out (got {out_ty})")

  if supports_transpose:
    num_imm_regs = 4
  elif out_ty == i32:
    num_imm_regs = 0
  else:
    num_imm_regs = 2

  if a_in_regs:
    a_reg_constraints = ["r"] * 4  # 4x (b)f16x2 or s8x4 registers
    if supports_transpose:
      num_imm_regs -= 1  # transpose not supported for a in registers
  else:
    a_reg_constraints = ["l"]  # descriptor
  # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
  # Seems like it's not actually documented in LLVM IR docs.
  reg_constraints_list = (
      [f"={acc_constraint}"] * num_acc_regs  # accumulator registers
      + [str(i) for i in range(num_acc_regs)]  # we alias outputs as inputs, too.
      + a_reg_constraints  # a descriptor / registers
      + ["l"] * 1  # b descriptor
      + ["n"] * (1 + num_imm_regs)  # literal constants
  )
  reg_constraints = ",".join(reg_constraints_list)
  reg_count = itertools.count()

  def take_regs(n):
    return (f"${i}" for i in itertools.islice(reg_count, n))

  acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
  for _ in take_regs(num_acc_regs):  # Ignore next entries: aliasing.
    pass
  if a_in_regs:
    a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
  else:
    a_regs, = take_regs(1)
  b_desc_reg, use_out_reg = take_regs(2)
  # Immediate regs (scale, ...).
  imm_regs = "".join(f", {r}" for r in take_regs(num_imm_regs))
  assert next(reg_count) == len(reg_constraints_list)
  k_instr = 32 // bytewidth(element_type)
  el_ty = str(element_type)
  if ir.Float8E5M2Type.isinstance(element_type):
    el_ty = "e5m2"
  elif ir.Float8E4M3FNType.isinstance(element_type):
    el_ty = "e4m3"
  elif ir.IntegerType.get_signless(8).isinstance(element_type):
    # TODO(bchetioui): add u8 support in the future. Currently we always assume
    # that 8-bit integers are s8, and we would need to change the signature of
    # `wgmma` to indicate whether the input should be treated as signed or not.
    el_ty = "s8"

  out_ty_str = str(out_ty)
  if out_ty == i32:
    out_ty_str = "s32"

  wgmma_instr = (
      f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty_str}.{el_ty}.{el_ty} "
      f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p{imm_regs};"
  )
  ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"

  def lc(x):
    return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result

  use_out = scale_a = scale_b = lc(1)
  if out_ty == i32:
    imms = [use_out]
  else:
    imms = [use_out, scale_a, scale_b]

  if supports_transpose and a_transpose is not None:
    imms += [lc(int(a_transpose)), lc(int(b_transpose))]
  elif supports_transpose:
    imms += [lc(int(b_transpose))]

  assert len(imms) == num_imm_regs + 1  # +1 for the use_out_reg in setp.ne.b32

  expected_dim = 10 if utils.bitwidth(out_ty) == 32 else 9
  expected_regs_per_tile = 4 if utils.bitwidth(out_ty) == 32 else 2
  if acc.ndim != expected_dim or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != expected_regs_per_tile:
    raise ValueError(acc.shape)
  acc_struct_type = ir.Type.parse(
      f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
  )
  for i in range((swizzle // bytewidth(element_type)) // k_instr):
    # Slice out the relevant part of A or advance the A descriptor.
    if a_in_regs:
      a_slice = a[:, (i * k_instr) : ((i + 1) * k_instr)]
      a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
    else:
      if i > 0:
        assert a_k_stride is not None
        a = _llvm_add(
            a,
            llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
        )
      a_args = [a]
    # Advance the B descriptor.
    if i > 0:
      b_descriptor = _llvm_add(
          b_descriptor,
          llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
      )
    assert len(a_args) == len(a_reg_constraints)
    acc_struct = llvm.inline_asm(
        acc_struct_type,
        [*acc_regs, *a_args, b_descriptor, *imms],
        ptx,
        reg_constraints,
        asm_dialect=0,
        has_side_effects=True,
    )
    acc_regs = [
        llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
    ]
  return to_acc_vec_regs(acc_regs)


def wgmma(
    acc: WGMMAAccumulator,
    a: fa.FragmentedArray | ir.Value,
    b: ir.Value,
    *,
    swizzle: int = 128,
):
  """Perform acc += a @ b using the WGMMA instruction.

  `a` may be passed in registers, or as a memref. `b` must be a memref.

  The expected (logical) memref shapes are:
    a: (m // tile_m, k // tile_k, tile_m, tile_k)
    b: (k // tile_k, n // tile_n, tile_k, tile_n).

  While the shapes may be physically transposed, when considering the row-major
  physical shape, the tile dimensions must be the two minor dimensions and must
  have the shape (8, S) where S = swizzle // bytewidth(element_type).
  """
  if swizzle == 16:
    raise NotImplementedError("No swizzle is not supported")
  # Step 1. Establish the shape and element type of the operation.
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
  bf16 = ir.BF16Type.get()
  f32 = ir.F32Type.get()
  f16 = ir.F16Type.get()
  i32 = ir.IntegerType.get_signless(32)
  i8 = ir.IntegerType.get_signless(8)
  f8e5m2 = ir.Float8E5M2Type.get()
  f8e4m3fn = ir.Float8E4M3FNType.get()
  (k, n), element_type = mma_utils.tiled_memref_shape(b)
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    m, k2 = a.shape
    element_type2 = a.mlir_dtype
    if element_type2 not in {f16, bf16, i8, f8e5m2, f8e4m3fn}:
      raise ValueError(
          "Only f16, bf16, i8, f8e5m2, f8e4m3fn are supported for A "
          f"in registers, got {element_type2}"
      )
    if element_type2 == i8 and swizzle == 32:
      # TODO(bchetioui): relax this when ptxas is fixed. As of ptxas 12.8,
      # optimizations eliminate MMA instructions, leading to only the first tile
      # of the result being computed correctly.
      raise NotImplementedError("swizzle=32 not supported for s8 lhs in registers")
  elif ir.MemRefType.isinstance(a.type):
    (m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
  else:
    raise ValueError(f"Unsupported A type: {type(a)}")
  if k != k2:
    raise ValueError(
        "WGMMA requires A and B to have the same contraction dimension (K),"
        f" got: {k2} and {k}"
    )
  if element_type != element_type2:
    raise ValueError(
        "WGMMA requires A and B to have the same element type, got:"
        f" {element_type2} and {element_type}"
    )
  if acc._value.shape != (m, n):
    raise ValueError(
        f"Accumulator shape mismatch: expected {(m, n)}, got {acc._value.shape}"
    )
  if element_type == f32 or element_type == ir.BF16Type.get():
    if acc._value.mlir_dtype != f32:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators"
          f" of type f32, but got: {acc._value.mlir_dtype}"
      )
  elif any(
      t.isinstance(element_type)
      for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType}
  ):
    if acc._value.mlir_dtype != f16 and acc._value.mlir_dtype != f32:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators "
          f"of type f32 or f16, but got: {acc._value.mlir_dtype}"
      )
  elif element_type == i8:
    if a_in_regs and not a.is_signed:
      raise NotImplementedError("WGMMA with lhs of type u8")
    if acc._value.mlir_dtype != i32 or not acc._value.is_signed:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators "
          f"of type s32, but got: {acc._value.mlir_dtype}"
      )
  else:
    raise NotImplementedError(f"Unsupported element type: {element_type}")

  # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
  # instructions must be issued in groups of the same width as the swizzle.
  m_group_elems = 64  # Hopper has a fixed M instruction shape.
  k_group_elems = swizzle // utils.bytewidth(element_type)
  if n > 256 or n % 8:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  n_group_elems = n  # We assume only one N group below.
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
    raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
  m_groups = m // m_group_elems
  k_groups = k // k_group_elems
  # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
  wgmma_element_type = (
      ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
  )

  # Step 3. Compute the operand descriptors.
  if a_in_regs:
    a_desc_base = a_m_group_stride = a_k_group_stride = None
    a_instr_params = dict(a_transpose=None, a_k_stride=None)
  else:
    (
        (a_desc_base, a_k_instr_stride),
        (a_m_group_stride, a_k_group_stride),
        a_fastest,
    ) = mma_utils.create_descriptor(
        a,
        swizzle=swizzle,
        large_tile=(m_group_elems, k_group_elems),
        group_size=(m_group_elems, k_group_elems),
        logical_k_major=False,
    )
    assert not a_k_instr_stride[0]  # We'd need separate a/b swizzles.
    a_k_instr_stride = a_k_instr_stride[1][0]
    a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
                          a_k_stride=a_k_instr_stride)
  (
      (b_desc_base, b_k_instr_stride),
      (b_n_group_stride, b_k_group_stride),
      b_fastest,
  ) = mma_utils.create_descriptor(
      b,
      swizzle=swizzle,
      large_tile=(k_group_elems,) * 2,  # It's not a typo that we use k for n.
      group_size=(k_group_elems, n_group_elems),
      logical_k_major=True,
  )
  assert not b_k_instr_stride[0]  # We'd need separate a/b swizzles.
  b_k_instr_stride = b_k_instr_stride[1][0]
  del b_n_group_stride  # We only support one N group.

  # Step 4. Issue the instructions.
  if a_in_regs:
    a = wgmma_fence(a)  # Make sure the registers are ready.

  i64 = ir.IntegerType.get_signless(64)
  new_acc_regs = acc._value.registers.copy()
  for mi in range(m_groups):
    for ki in range(k_groups):
      if a_in_regs:
        a_mk = a[
            mi * m_group_elems : (mi + 1) * m_group_elems,
            ki * k_group_elems : (ki + 1) * k_group_elems,
        ]
      else:
        assert a_m_group_stride is not None and a_k_group_stride is not None
        a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
        a_mk = _llvm_add(
            a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
        )
      b_k = _llvm_add(
          b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
      )
      new_acc_regs[mi : mi + 1] = wgmma_m64(
          new_acc_regs[mi : mi + 1],
          a_mk,
          b_k,
          swizzle=swizzle,
          n=n_group_elems,
          element_type=wgmma_element_type,
          b_transpose=b_fastest != mma_utils.Dim.K,
          b_k_stride=b_k_instr_stride,
          **a_instr_params,
      )
  return WGMMAAccumulator(
      _value=fa.FragmentedArray(
          _registers=new_acc_regs,
          _layout=acc._value.layout,
          _is_signed=acc._value.is_signed,
      ),
      _original_layout=acc._original_layout,
      _sync=False,
  )


def wgmma_fence(array: fa.FragmentedArray) -> fa.FragmentedArray:
  """Fences the array construction from WGMMA instructions.

  LLVM treats in-register computation as pure and can move it after the fence,
  which is explicitly disallowed by the PTX programming model. For that reason,
  we insert an LLVM optimization barrier before the fence.
  """
  array = fa.optimization_barrier(array)
  nvvm.wgmma_fence_aligned()
  return array


def _as_i32_reg(v):
  i32 = ir.IntegerType.get_signless(32)
  return llvm.extractelement(
      vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
  )


def _lc(x):
  i32 = ir.IntegerType.get_signless(32)
  return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result


def _llvm_add(x, y):
  return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)


def _unpack_i32(vec_ty, r):
  i32 = ir.IntegerType.get_signless(32)
  return vector.bitcast(
      vec_ty, vector.broadcast(ir.VectorType.get((1,), i32), r)
  )
